#pragma once

#include <array>
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <Service/ACLMap.h>
#include <Service/SessionManager.h>
#include <Service/WatchManager.h>
#include <Service/ThreadSafeQueue.h>
#include <Service/KeeperCommon.h>
#include <Service/formatHex.h>
#include <ZooKeeper/IKeeper.h>
#include <Poco/Logger.h>
#include <Common/ConcurrentBoundedQueue.h>
#include <Common/IO/Operators.h>
#include <Common/IO/WriteBufferFromString.h>
#include <Common/ThreadPool.h>
#include <common/logger_useful.h>

namespace RK
{

/**
 * One entry of the data tree: payload, ACL reference, node flags,
 * the ZooKeeper stat and the set of direct children.
 */
struct KeeperNode
{
    using ChildrenSet = std::unordered_set<String>;

    /// Node payload.
    String data;

    /// Id referencing this node's ACL list (presumably a key of KeeperStore's ACLMap — verify).
    uint64_t acl_id = 0;

    /// Node creation flags.
    bool is_ephemeral = false;
    bool is_sequential = false;

    Coordination::Stat stat{};

    /// Direct children of this node.
    ChildrenSet children;

    /// Copy helpers; implementations live outside this header.
    /// By its name, cloneWithoutChildren() leaves `children` empty — confirm in the .cpp.
    std::shared_ptr<KeeperNode> clone() const;
    std::shared_ptr<KeeperNode> cloneWithoutChildren() const;

    /// All stat for client should be generated by this function.
    /// This method will remove numChildren from persisted stat.
    Coordination::Stat statForResponse() const;

    /// Field-wise equality over data, acl_id, both flags and children.
    /// NOTE: `stat` is not part of the comparison.
    bool operator==(const KeeperNode & rhs) const
    {
        if (data != rhs.data || acl_id != rhs.acl_id)
            return false;
        if (is_ephemeral != rhs.is_ephemeral || is_sequential != rhs.is_sequential)
            return false;
        return children == rhs.children;
    }

    bool operator!=(const KeeperNode & rhs) const { return !(*this == rhs); }
};

/// Shared handle to a data-tree node.
using KeeperNodePtr = std::shared_ptr<KeeperNode>;

/// A node paired with the path it is stored under.
struct KeeperNodeWithPath
{
    /// Path of the node.
    String path{};
    /// The node itself (empty pointer by default).
    KeeperNodePtr node{};
};

/// KeeperNodeMap is a two-level unordered_map which is designed to reduce latency for unordered_map scaling.
/// It is not a thread-safe map. But it is accessed only in the request processor thread.
template <typename Value, unsigned NumBuckets>
class KeeperNodeMap
{
public:
    using Key = String;
    using ValuePtr = std::shared_ptr<Value>;
    using NestedMap = std::unordered_map<String, ValuePtr>;
    using Action = std::function<void(const String &, const ValuePtr &)>;

    class InnerMap
    {
    public:
        ValuePtr get(const String & key)
        {
            auto i = map.find(key);
            return (i != map.end()) ? i->second : nullptr;
        }

        template <typename T>
        bool emplace(const String & key, T && value)
        {
            return map.insert_or_assign(key, value).second;
        }

        bool erase(const String & key)
        {
            return map.erase(key);
        }

        size_t size() const
        {
            return map.size();
        }

        void clear()
        {
            map.clear();
        }

        void forEach(const Action & fn)
        {
            for (const auto & [key, value] : map)
                fn(key, value);
        }

        /// This method will destroy InnerMap thread safety property.
        /// Deprecated, please use forEach instead.
        NestedMap & getMap() { return map; }

    private:
        NestedMap map;
    };

private:
    inline InnerMap & mapFor(const String & key) { return buckets[hash(key) % NumBuckets]; }

    std::array<InnerMap, NumBuckets> buckets;
    std::hash<String> hash;
    std::atomic<size_t> node_count{0};

public:
    ValuePtr get(const String & key) { return mapFor(key).get(key); }
    ValuePtr at(const String & key) { return mapFor(key).get(key); }

    template <typename T>
    bool emplace(const String & key, T && value)
    {
        if (mapFor(key).emplace(key, std::forward<T>(value)))
        {
            node_count++;
            return true;
        }
        return false;
    }

    template <typename T>
    bool emplace(const String & key, T && value, UInt32 bucket_id)
    {
        if (buckets[bucket_id].emplace(key, std::forward<T>(value)))
        {
            node_count++;
            return true;
        }
        return false;
    }

    bool erase(String const & key)
    {
        if (mapFor(key).erase(key))
        {
            node_count--;
            return true;
        }
        return false;
    }

    size_t count(const String & key) { return get(key) != nullptr ? 1 : 0; }

    UInt32 getBucketIndex(const String & key) { return hash(key) % NumBuckets; }
    UInt32 getBucketNum() const { return NumBuckets; }

    InnerMap & getMap(const UInt32 & bucket_id) { return buckets[bucket_id]; }

    void clear()
    {
        for (auto & bucket : buckets)
            bucket.clear();
        node_count.store(0);
    }

    size_t size() const
    {
        return node_count.load();
    }
};

/// KeeperStore holds the data tree, sessions, watches and auths. It sits under the state machine.
class KeeperStore
{
public:
    /// bucket num for KeeperNodeMap
    static constexpr int DATA_TREE_BUCKET_NUM = 16;
    using DataTree = KeeperNodeMap<KeeperNode, DATA_TREE_BUCKET_NUM>;

    using KeeperResponsesQueue = ThreadSafeQueue<ResponseForSession>;

    /// session id -> auth ids added for that session
    using SessionAndAuth = std::unordered_map<int64_t, Coordination::AuthIDs>;
    /// session id -> paths of the ephemeral nodes owned by that session
    using Ephemerals = std::unordered_map<int64_t, std::unordered_set<String>>;

    /// Hold edges in different buckets based on the parent node's bucket number.
    /// It should be used when loading a snapshot to build nodes' children sets in parallel without locks.
    using Edge = std::pair<String, String>;
    using Edges = std::vector<Edge>;
    using BucketEdges = std::array<Edges, DATA_TREE_BUCKET_NUM>;
    using BucketNodes = std::array<std::vector<std::pair<String, std::shared_ptr<KeeperNode>>>, DATA_TREE_BUCKET_NUM>;

    explicit KeeperStore(int64_t dead_session_check_period_ms, const String & super_digest_ = "");

    /// process request
    void processRequest(
        ThreadSafeQueue<ResponseForSession> & responses_queue,
        const RequestForSession & request_for_session,
        std::optional<int64_t> new_last_zxid = {}, /// empty when we are converting zookeeper log to raftkeeper data.
        bool check_acl = true,
        bool ignore_response = false);

    /// Build children set after loading data from snapshot
    void buildChildrenSet(bool from_zk_snapshot = false);

    /// Build children set for the nodes in specified bucket after load data from snapshot.
    void buildBucketChildren(const std::vector<BucketEdges> & all_objects_edges, UInt32 bucket_id);
    void fillDataTreeBucket(const std::vector<BucketNodes> & all_objects_nodes, UInt32 bucket_id);

    /// Clean ephemeral nodes, invoked when shutdown
    void finalize();

    /// Clear whole store and set to initial state.
    void reset();

    /// Used when creating snapshot
    std::shared_ptr<BucketNodes> dumpDataTree();

    int64_t getZxid() const
    {
        return zxid.load();
    }

    void setZxid(int64_t new_zxid)
    {
        zxid.store(new_zxid);
    }

    /// Returns a copy of all session auths. Read-only, so a shared lock is enough.
    SessionAndAuth getSessionAndAuth() const
    {
        std::shared_lock lock(auth_mutex);
        return session_and_auth;
    }

    /// Sets (replaces) the auth ids of `session_id`.
    void addSessionAuth(int64_t session_id, const Coordination::AuthIDs & auth)
    {
        std::lock_guard lock(auth_mutex);
        /// `auth` is a const reference, so this is a copy (std::move would be a silent copy too).
        session_and_auth[session_id] = auth;
    }

    std::vector<int64_t> getDeadSessions() const
    {
        return session_manager.getDeadSessions();
    }

    std::unordered_map<int64_t, int64_t> sessionToExpirationTime() const
    {
        return session_manager.sessionToExpirationTime();
    }

    inline void handleRemoteSession(int64_t session_id, int64_t expiration_time)
    {
        session_manager.handleRemoteSession(session_id, expiration_time);
    }

    inline bool containsSession(int64_t session_id) const
    {
        return session_manager.contains(session_id);
    }

    size_t getDataTreeBucketNum() const
    {
        return data_tree.getBucketNum();
    }

    /// Returns the node at `path`, or nullptr if absent.
    inline KeeperNodePtr getNode(const String & path)
    {
        return data_tree.get(path);
    }

    inline bool exists(const String & path)
    {
        return data_tree.count(path);
    }

    /// Inserts or replaces the node at `path`.
    inline void addNode(const String & path, KeeperNodePtr node)
    {
        /// `node` is our own copy; move it to avoid an extra ref-count round trip.
        data_tree.emplace(path, std::move(node));
    }

    inline void removeNode(const String & path)
    {
        data_tree.erase(path);
    }

    inline void addEphemeralNode(int64_t session_id, const String & path)
    {
        std::lock_guard lock(ephemerals_mutex);
        ephemerals[session_id].insert(path);
    }

    /// Forgets that `session_id` owns the ephemeral node at `path`.
    inline void removeEphemeralNode(int64_t session_id, const String & path)
    {
        std::lock_guard lock(ephemerals_mutex);
        auto it = ephemerals.find(session_id);
        /// Do not create an entry for a session that owns nothing.
        if (it == ephemerals.end())
            return;
        it->second.erase(path);
        /// Drop sessions with no ephemeral nodes left, so that
        /// getSessionWithEphemeralNodesCount() only counts sessions that really own some.
        if (it->second.empty())
            ephemerals.erase(it);
    }

    const String & getSuperDigest() const
    {
        return super_digest;
    }


    /// Introspection functions mostly used in 4-letter commands ///

    uint64_t getNodesCount() const { return data_tree.size(); }
    uint64_t getApproximateDataSize() const;

    uint64_t getSessionWithEphemeralNodesCount() const
    {
        std::lock_guard lock(ephemerals_mutex);
        return ephemerals.size();
    }

    uint64_t getTotalEphemeralNodesCount() const;
    void dumpSessionsAndEphemerals(WriteBufferFromOwnString & buf) const;

    SessionManager::SessionAndTimeout getSessionAndTimeOut() const
    {
        return session_manager.getSessionAndTimeOut();
    }

    uint64_t getSessionID(uint64_t session_timeout_ms)
    {
        return session_manager.getSessionID(session_timeout_ms);
    }

    bool updateSessionTimeout(int64_t session_id, int64_t session_timeout_ms)
    {
        return session_manager.updateSessionTimeout(session_id, session_timeout_ms);
    }

    void addSessionID(int64_t session_id, int64_t session_timeout_ms)
    {
        session_manager.addSessionID(session_id, session_timeout_ms);
    }

    int64_t getSessionCount() const
    {
        return session_manager.getSessionCount();
    }

    int64_t getSessionIDCounter() const
    {
        return session_manager.getSessionIDCounter();
    }

    void setSessionIDCounter(int64_t counter)
    {
        session_manager.setSessionIDCounter(counter);
    }

    DataTree & getDataTree()
    {
        return data_tree;
    }

    inline size_t getBucketIndex(const String & path)
    {
        return data_tree.getBucketIndex(path);
    }

    ACLMap & getACLMap()
    {
        return acl_map;
    }

    void addACLs(uint64_t acls_id, const Coordination::ACLs & acls)
    {
        acl_map.addMapping(acls_id, acls);
    }

    /// NOTE(review): the lock below only covers taking the reference; it is released
    /// before the caller reads the map. Callers must ensure no concurrent
    /// add/removeEphemeralNode while they use the result (e.g. only from the processor
    /// thread or while serialized with it).
    const Ephemerals & getEphemerals() const
    {
        std::lock_guard lock(ephemerals_mutex);
        return ephemerals;
    }

    /// watch related functions

    uint64_t getWatchedPathsCount() const
    {
        return watch_manager.getWatchedPathsCount();
    }

    uint64_t getTotalWatchesCount() const
    {
        return watch_manager.getTotalWatchesCount();
    }

    uint64_t getSessionsWithWatchesCount() const
    {
        return watch_manager.getSessionsWithWatchesCount();
    }

    void dumpWatches(WriteBufferFromOwnString & buf) const
    {
        watch_manager.dumpWatches(buf);
    }

    void dumpWatchesByPath(WriteBufferFromOwnString & buf) const
    {
        watch_manager.dumpWatchesByPath(buf);
    }

    void initializeSystemNodes();

    /// Guards session_and_auth: shared lock for reads, exclusive lock for writes.
    mutable std::shared_mutex auth_mutex;
    SessionAndAuth session_and_auth;

    /// ACLMap for more compact ACLs storage inside nodes.
    ACLMap acl_map;

private:
    /// Returns the current zxid and then increments it.
    int64_t fetchAndGetZxid() { return zxid++; }
    void cleanEphemeralNodes(int64_t session_id, ThreadSafeQueue<ResponseForSession> & responses_queue, bool ignore_response);

    /// data tree
    DataTree data_tree;

    SessionManager session_manager;
    WatchManager watch_manager;

    /// All ephemeral nodes go here, guarded by ephemerals_mutex.
    Ephemerals ephemerals;
    mutable std::mutex ephemerals_mutex;

    /// Global transaction id, only write request will consume zxid.
    /// It should be same across all nodes.
    std::atomic<int64_t> zxid{0};

    /// finalized flag
    bool finalized{false};

    const String super_digest;

    Poco::Logger * log;
};

}
